d7344785d7ba51ff432881980f79f207a820034c,datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java,SupportVectorMachineTest,testPredict,#,49
Before Change
String storageName = this.getClass().getSimpleName();
DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), configuration);
df.fit_transform(trainingData);
df.save(storageName);
SupportVectorMachine.TrainingParameters param = new SupportVectorMachine.TrainingParameters();
param.getSvmParameter().kernel_type = svm_parameter.RBF;
SupportVectorMachine instance = MLBuilder.create(param, configuration);
instance.fit(trainingData);
instance.save(storageName);
df.denormalize(trainingData);
trainingData.close();
instance.close();
df.close();
//instance = null;
//df = null;
df = MLBuilder.load(DummyXYMinMaxNormalizer.class, storageName, configuration);
instance = MLBuilder.load(SupportVectorMachine.class, storageName, configuration);
df.transform(validationData);
instance.predict(validationData);
df.denormalize(validationData);
Map<Integer, Object> expResult = new HashMap<>();
After Change
String storageName = this.getClass().getSimpleName();
MinMaxScaler.TrainingParameters nsParams = new MinMaxScaler.TrainingParameters();
nsParams.setScaleResponse(true);
MinMaxScaler numericalScaler = MLBuilder.create(nsParams, configuration);
numericalScaler.fit_transform(trainingData);
numericalScaler.save(storageName);
CornerConstraintsEncoder.TrainingParameters ceParams = new CornerConstraintsEncoder.TrainingParameters();
CornerConstraintsEncoder categoricalEncoder = MLBuilder.create(ceParams, configuration);
categoricalEncoder.fit_transform(trainingData);
categoricalEncoder.save(storageName);
SupportVectorMachine.TrainingParameters param = new SupportVectorMachine.TrainingParameters();
param.getSvmParameter().kernel_type = svm_parameter.RBF;
SupportVectorMachine instance = MLBuilder.create(param, configuration);
instance.fit(trainingData);
instance.save(storageName);
trainingData.close();
instance.close();
numericalScaler.close();
categoricalEncoder.close();
numericalScaler = MLBuilder.load(MinMaxScaler.class, storageName, configuration);
categoricalEncoder = MLBuilder.load(CornerConstraintsEncoder.class, storageName, configuration);
instance = MLBuilder.load(SupportVectorMachine.class, storageName, configuration);
numericalScaler.transform(validationData);
categoricalEncoder.transform(validationData);
instance.predict(validationData);
Map<Integer, Object> expResult = new HashMap<>();
Map<Integer, Object> result = new HashMap<>();
for(Map.Entry<Integer, Record> e : validationData.entries()) {
Integer rId = e.getKey();
Record r = e.getValue();
expResult.put(rId, r.getY());
result.put(rId, r.getYPredicted());
}
assertEquals(expResult, result);
numericalScaler.delete();
categoricalEncoder.delete();
instance.delete();
validationData.close();